{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dc0e65bf",
   "metadata": {
    "origin_pos": 1
   },
   "source": [
    "# Parameter Initialization\n",
    "\n",
    "Now that we know how to access the parameters,\n",
    "let's look at how to initialize them properly.\n",
    "We discussed the need for proper initialization in :numref:`sec_numerical_stability`.\n",
    "The deep learning framework provides default random initializations to its layers.\n",
    "However, we often want to initialize our weights\n",
    "according to various other protocols. The framework provides the most commonly\n",
    "used protocols and also allows us to create a custom initializer.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "57c14f80",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:09.192619Z",
     "iopub.status.busy": "2023-08-18T19:43:09.192194Z",
     "iopub.status.idle": "2023-08-18T19:43:11.079730Z",
     "shell.execute_reply": "2023-08-18T19:43:11.078721Z"
    },
    "origin_pos": 3,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "315e66e5",
   "metadata": {
    "origin_pos": 7,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "By default, PyTorch initializes the weights and biases of a layer\n",
    "by drawing uniformly from a range that is computed from the layer's dimensions\n",
    "(for `nn.Linear`, the input dimension).\n",
    "PyTorch's `nn.init` module provides a variety\n",
    "of preset initialization methods.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d2a5e8cb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:11.084761Z",
     "iopub.status.busy": "2023-08-18T19:43:11.083735Z",
     "iopub.status.idle": "2023-08-18T19:43:11.122261Z",
     "shell.execute_reply": "2023-08-18T19:43:11.121411Z"
    },
    "origin_pos": 11,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))\n",
    "X = torch.rand(size=(2, 4))\n",
    "net(X).shape"
   ]
  },
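  {
   "cell_type": "markdown",
   "id": "b7e41c02",
   "metadata": {
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "The default scheme can be checked directly.\n",
    "A minimal sketch, assuming the current `nn.Linear` defaults:\n",
    "we build a standalone layer (the sizes 4 and 8 are arbitrary)\n",
    "and compare its freshly initialized parameters against the bound\n",
    "$1/\\sqrt{n_\\textrm{in}}$, where $n_\\textrm{in}$ is the input dimension.\n",
    "This bound may differ for other layer types or PyTorch versions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c93a5d17",
   "metadata": {
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Hypothetical standalone layer; the sizes are illustrative only\n",
    "layer = nn.Linear(4, 8)\n",
    "# Assumed default bound for nn.Linear: 1 / sqrt(input dimension)\n",
    "bound = 1 / math.sqrt(layer.in_features)\n",
    "\n",
    "# Largest magnitudes among the freshly initialized weights and biases\n",
    "w_max = layer.weight.detach().abs().max().item()\n",
    "b_max = layer.bias.detach().abs().max().item()\n",
    "bound, w_max, b_max"
   ]
  },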
  {
   "cell_type": "markdown",
   "id": "18a9c8b0",
   "metadata": {
    "origin_pos": 14
   },
   "source": [
    "## [**Built-in Initialization**]\n",
    "\n",
    "Let's begin by calling on built-in initializers.\n",
    "The code below initializes all weight parameters\n",
    "as Gaussian random variables\n",
    "with standard deviation 0.01, while bias parameters are set to zero.\n",
    "Calling `net.apply` invokes the given function on every submodule of `net`,\n",
    "so the function only needs to handle a single layer.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6059e0fb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:11.127859Z",
     "iopub.status.busy": "2023-08-18T19:43:11.127125Z",
     "iopub.status.idle": "2023-08-18T19:43:11.135507Z",
     "shell.execute_reply": "2023-08-18T19:43:11.134596Z"
    },
    "origin_pos": 16,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([-0.0129, -0.0007, -0.0033,  0.0276]), tensor(0.))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def init_normal(module):\n",
    "    # After the forward pass above, each LazyLinear has become an nn.Linear\n",
    "    if type(module) == nn.Linear:\n",
    "        nn.init.normal_(module.weight, mean=0, std=0.01)\n",
    "        nn.init.zeros_(module.bias)\n",
    "\n",
    "net.apply(init_normal)\n",
    "net[0].weight.data[0], net[0].bias.data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb3e91fa",
   "metadata": {
    "origin_pos": 19
   },
   "source": [
    "We can also initialize all the weights\n",
    "to a given constant value (say, 1), again setting the biases to zero.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d2007d64",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:11.138851Z",
     "iopub.status.busy": "2023-08-18T19:43:11.138302Z",
     "iopub.status.idle": "2023-08-18T19:43:11.145695Z",
     "shell.execute_reply": "2023-08-18T19:43:11.144862Z"
    },
    "origin_pos": 21,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([1., 1., 1., 1.]), tensor(0.))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def init_constant(module):\n",
    "    if type(module) == nn.Linear:\n",
    "        nn.init.constant_(module.weight, 1)\n",
    "        nn.init.zeros_(module.bias)\n",
    "\n",
    "net.apply(init_constant)\n",
    "net[0].weight.data[0], net[0].bias.data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "130ca37e",
   "metadata": {
    "origin_pos": 24
   },
   "source": [
    "[**We can also apply different initializers to different blocks.**]\n",
    "For example, below we initialize the weights of the first layer\n",
    "with the Xavier initializer\n",
    "and initialize the weights of the second layer\n",
    "to a constant value of 42.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4734e6eb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:11.149040Z",
     "iopub.status.busy": "2023-08-18T19:43:11.148497Z",
     "iopub.status.idle": "2023-08-18T19:43:11.155752Z",
     "shell.execute_reply": "2023-08-18T19:43:11.154840Z"
    },
    "origin_pos": 26,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-0.0974,  0.1707,  0.5840, -0.5032])\n",
      "tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])\n"
     ]
    }
   ],
   "source": [
    "def init_xavier(module):\n",
    "    if type(module) == nn.Linear:\n",
    "        nn.init.xavier_uniform_(module.weight)\n",
    "\n",
    "def init_42(module):\n",
    "    if type(module) == nn.Linear:\n",
    "        nn.init.constant_(module.weight, 42)\n",
    "\n",
    "net[0].apply(init_xavier)\n",
    "net[2].apply(init_42)\n",
    "print(net[0].weight.data[0])\n",
    "print(net[2].weight.data)"
   ]
  },
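  {
   "cell_type": "markdown",
   "id": "d1f08a3e",
   "metadata": {
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "With its default gain of 1, `nn.init.xavier_uniform_` samples from $U(-a, a)$\n",
    "with $a = \\sqrt{6 / (n_\\textrm{in} + n_\\textrm{out})}$.\n",
    "As a rough check, the sketch below applies it to a standalone layer\n",
    "(the sizes are chosen only for illustration)\n",
    "and compares the largest weight magnitude against $a$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5a92b64",
   "metadata": {
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Hypothetical standalone layer; the sizes are illustrative only\n",
    "layer = nn.Linear(4, 8)\n",
    "nn.init.xavier_uniform_(layer.weight)\n",
    "\n",
    "# Xavier uniform bound for gain 1: sqrt(6 / (fan_in + fan_out))\n",
    "a = math.sqrt(6 / (layer.in_features + layer.out_features))\n",
    "xavier_max = layer.weight.detach().abs().max().item()\n",
    "a, xavier_max"
   ]
  },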
  {
   "cell_type": "markdown",
   "id": "164af268",
   "metadata": {
    "origin_pos": 29
   },
   "source": [
    "## [**Custom Initialization**]\n",
    "\n",
    "Sometimes, the initialization methods we need\n",
    "are not provided by the deep learning framework.\n",
    "In the example below, we define an initializer\n",
    "for any weight parameter $w$ using the following strange distribution:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "    w \\sim \\begin{cases}\n",
    "        U(5, 10) & \\textrm{ with probability } \\frac{1}{4} \\\\\n",
    "            0    & \\textrm{ with probability } \\frac{1}{2} \\\\\n",
    "        U(-10, -5) & \\textrm{ with probability } \\frac{1}{4}\n",
    "    \\end{cases}\n",
    "\\end{aligned}\n",
    "$$\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac721402",
   "metadata": {
    "origin_pos": 31,
    "tab": [
     "pytorch"
    ]
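   },
   "source": [
    "Again, we implement a `my_init` function to apply to `net`.\n",
    "It draws every weight from $U(-10, 10)$ and then zeros out the entries whose magnitude is below 5.\n",
    "This happens with probability $\\frac{1}{2}$, and the surviving entries are uniform on $(-10, -5)$ and $(5, 10)$.\n"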
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "334b9bed",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:11.159032Z",
     "iopub.status.busy": "2023-08-18T19:43:11.158501Z",
     "iopub.status.idle": "2023-08-18T19:43:11.166911Z",
     "shell.execute_reply": "2023-08-18T19:43:11.166067Z"
    },
    "origin_pos": 35,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Init weight torch.Size([8, 4])\n",
      "Init weight torch.Size([1, 8])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000, -7.6364, -0.0000, -6.1206],\n",
       "        [ 9.3516, -0.0000,  5.1208, -8.4003]], grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def my_init(module):\n",
    "    if type(module) == nn.Linear:\n",
    "        # Report the name and shape of the first parameter (the weight)\n",
    "        print(\"Init\", *[(name, param.shape)\n",
    "                        for name, param in module.named_parameters()][0])\n",
    "        # Draw from U(-10, 10), then zero out entries with magnitude below 5\n",
    "        nn.init.uniform_(module.weight, -10, 10)\n",
    "        module.weight.data *= module.weight.data.abs() >= 5\n",
    "\n",
    "net.apply(my_init)\n",
    "net[0].weight[:2]"
   ]
  },
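  {
   "cell_type": "markdown",
   "id": "f2c6017b",
   "metadata": {
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "As an empirical check, the sketch below applies the same two steps as `my_init`\n",
    "to a large standalone tensor (its size and the seed are arbitrary choices)\n",
    "and measures how often each of the three cases occurs.\n",
    "The fractions should come out close to $\\frac{1}{4}$, $\\frac{1}{2}$, and $\\frac{1}{4}$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a80d39c5",
   "metadata": {
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)  # arbitrary seed, used only for reproducibility\n",
    "w = torch.empty(1000, 1000)  # hypothetical tensor standing in for a weight\n",
    "nn.init.uniform_(w, -10, 10)\n",
    "w *= w.abs() >= 5  # zero out entries with magnitude below 5\n",
    "\n",
    "# Empirical share of positive, zero, and negative entries\n",
    "frac_pos = (w > 0).float().mean().item()\n",
    "frac_zero = (w == 0).float().mean().item()\n",
    "frac_neg = (w < 0).float().mean().item()\n",
    "frac_pos, frac_zero, frac_neg"
   ]
  },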
  {
   "cell_type": "markdown",
   "id": "9f4c13b0",
   "metadata": {
    "origin_pos": 38,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "Note that we always have the option\n",
    "of setting parameters directly.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e38feecc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:43:11.170212Z",
     "iopub.status.busy": "2023-08-18T19:43:11.169683Z",
     "iopub.status.idle": "2023-08-18T19:43:11.176291Z",
     "shell.execute_reply": "2023-08-18T19:43:11.175385Z"
    },
    "origin_pos": 41,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([42.0000, -6.6364,  1.0000, -5.1206])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[0].weight.data[:] += 1\n",
    "net[0].weight.data[0, 0] = 42\n",
    "net[0].weight.data[0]"
   ]
  },
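  {
   "cell_type": "markdown",
   "id": "c4d7e9a1",
   "metadata": {
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "Going through `.data` is one way to bypass gradient tracking.\n",
    "Another common idiom is to modify the parameter in place\n",
    "inside a `torch.no_grad()` block.\n",
    "A minimal sketch on a hypothetical standalone layer (not the `net` above):\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b3e52f8",
   "metadata": {
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Hypothetical standalone layer; the sizes are illustrative only\n",
    "layer = nn.Linear(4, 8)\n",
    "with torch.no_grad():  # in-place updates here are not tracked by autograd\n",
    "    layer.weight.zero_()\n",
    "    layer.weight += 1\n",
    "    layer.weight[0, 0] = 42\n",
    "\n",
    "first_row = layer.weight[0].detach().clone()\n",
    "first_row"
   ]
  },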
  {
   "cell_type": "markdown",
   "id": "e812c7c4",
   "metadata": {
    "origin_pos": 43
   },
   "source": [
    "## Summary\n",
    "\n",
    "We can initialize parameters using built-in and custom initializers.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "Look up the online documentation for more built-in initializers.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "480a164a",
   "metadata": {
    "origin_pos": 45,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/8090)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}